﻿using System;
using System.Collections;
using System.Collections.Generic;

namespace IOP.Models.Tree.BinaryTree
{
    /// <summary>
    /// Red-black tree: a self-balancing binary search tree keyed by <typeparamref name="TKey"/>.
    /// </summary>
    /// <remarks>
    /// Insert and delete fix-ups follow the classic CLRS red-black algorithms.
    /// Missing (null) children are treated as BLACK by the colour helper.
    /// </remarks>
    public class RedBlackTree<TKey, TValue> : IBinaryTree<TKey, TValue>
        where TKey : IComparable<TKey>
    {
        /// <summary>
        /// Colour value used for red nodes.
        /// </summary>
        public const bool RED = true;
        /// <summary>
        /// Colour value used for black nodes.
        /// </summary>
        public const bool BLACK = false;

        /// <summary>
        /// Root node; <c>null</c> when the tree is empty.
        /// </summary>
        public RedBlackTreeNode<TKey, TValue> Root { get; set; } = null;

        /// <summary>
        /// Indexer; equivalent to <see cref="Get(TKey)"/>.
        /// </summary>
        /// <param name="key">Key to look up.</param>
        /// <returns>The stored value, or <c>default</c> when the key is absent.</returns>
        public virtual TValue this[TKey key] => Get(key);

        /// <summary>
        /// Number of keys stored in the tree.
        /// </summary>
        public int Count { get; set; } = 0;

        /// <summary>
        /// Reports whether <paramref name="key"/> is stored in the tree.
        /// </summary>
        /// <param name="key">Key to look up.</param>
        /// <param name="result">The associated value when found; otherwise <c>default</c>.</param>
        /// <returns><c>true</c> if the key exists; <c>false</c> otherwise, including when the tree is empty.</returns>
        public virtual bool Contains(TKey key, out TValue result)
        {
            var node = FindNode(key);
            if (node == null)
            {
                result = default(TValue);
                return false;
            }
            result = node.Value;
            return true;
        }

        /// <summary>
        /// Inserts <paramref name="key"/> with <paramref name="value"/>, or overwrites the value
        /// if the key is already present.
        /// </summary>
        /// <param name="key">Key to insert.</param>
        /// <param name="value">Value to associate with the key.</param>
        public virtual void Put(TKey key, TValue value)
        {
            if (Root == null)
            {
                Root = new RedBlackTreeNode<TKey, TValue>(key, value, BLACK);
                Count++;
                return;
            }

            RedBlackTreeNode<TKey, TValue> parent = null;
            RedBlackTreeNode<TKey, TValue> current = Root;
            int cmp = 0;
            while (current != null)
            {
                // CompareTo only guarantees the sign of its result, not -1/0/1.
                cmp = key.CompareTo(current.Key);
                if (cmp == 0)
                {
                    current.Value = value;
                    return;
                }
                parent = current;
                current = cmp < 0 ? current.Left : current.Right;
            }

            // New nodes must be RED for the insertion fix-up to be correct.
            var node = new RedBlackTreeNode<TKey, TValue>(key, value, RED);
            node.Parent = parent;
            // cmp still holds the comparison against 'parent' from the last loop iteration.
            if (cmp < 0) parent.Left = node;
            else parent.Right = node;
            Count++;
            PutFixUp(node);
        }

        /// <summary>
        /// Restores the red-black properties after inserting the red node <paramref name="current"/>.
        /// </summary>
        /// <param name="current">The freshly inserted node.</param>
        private void PutFixUp(RedBlackTreeNode<TKey, TValue> current)
        {
            if (current == null) return;
            // A red parent is never the root (the root is black), so the grandparent exists.
            while (current != Root && GetColor(GetParent(current)) == RED)
            {
                // Parent is the grandparent's left child.
                if (GetParent(current) == GetLeft(GetParent(GetParent(current))))
                {
                    RedBlackTreeNode<TKey, TValue> uncle = GetRight(GetParent(GetParent(current)));
                    // Case 1: uncle is red -> recolour and continue from the grandparent.
                    if (GetColor(uncle) == RED)
                    {
                        SetColor(uncle, BLACK);
                        SetColor(GetParent(current), BLACK);
                        SetColor(GetParent(GetParent(current)), RED);
                        current = GetParent(GetParent(current));
                    }
                    else
                    {
                        // Case 2: current is a right child -> rotate it into the outer position.
                        if (current == GetRight(GetParent(current)))
                        {
                            current = GetParent(current);
                            LeftRotation(current);
                        }
                        // Case 3: current is (now) a left child -> recolour and rotate the grandparent.
                        SetColor(GetParent(current), BLACK);
                        SetColor(GetParent(GetParent(current)), RED);
                        RightRotation(GetParent(GetParent(current)));
                    }
                }
                else
                {
                    // Parent is the grandparent's right child (mirror image of the above).
                    RedBlackTreeNode<TKey, TValue> uncle = GetLeft(GetParent(GetParent(current)));
                    // Case 4: uncle is red -> recolour and continue from the grandparent.
                    if (GetColor(uncle) == RED)
                    {
                        SetColor(GetParent(current), BLACK);
                        SetColor(uncle, BLACK);
                        SetColor(GetParent(GetParent(current)), RED);
                        current = GetParent(GetParent(current));
                    }
                    else
                    {
                        // Case 5: current is a left child -> rotate it into the outer position.
                        if (current == GetLeft(GetParent(current)))
                        {
                            current = GetParent(current);
                            RightRotation(current);
                        }
                        // Case 6: current is (now) a right child -> recolour and rotate the grandparent.
                        SetColor(GetParent(current), BLACK);
                        SetColor(GetParent(GetParent(current)), RED);
                        LeftRotation(GetParent(GetParent(current)));
                    }
                }
            }
            SetColor(Root, BLACK);
        }

        /// <summary>
        /// Removes <paramref name="key"/> from the tree. Does nothing if the key is absent
        /// or the tree is empty.
        /// </summary>
        /// <param name="key">Key to remove.</param>
        public virtual void Delete(TKey key)
        {
            var entry = FindNode(key);
            if (entry == null) return;
            Count--;
            Delete(entry);
        }

        /// <summary>
        /// Unlinks <paramref name="current"/> from the tree and rebalances.
        /// </summary>
        /// <param name="current">Node to remove.</param>
        private void Delete(RedBlackTreeNode<TKey, TValue> current)
        {
            if (current == null) return;
            // Two children: copy the in-order successor's entry here, then delete the successor,
            // which has at most one (right) child.
            if (current.Left != null && current.Right != null)
            {
                var min = FindMin(current.Right);
                current.Key = min.Key;
                current.Value = min.Value;
                current = min;
            }

            var replace = current.Left != null ? current.Left : current.Right;
            if (replace != null)
            {
                // Splice the single child into current's position.
                replace.Parent = current.Parent;
                if (current.Parent == null) Root = replace;
                else if (current == current.Parent.Left) current.Parent.Left = replace;
                else current.Parent.Right = replace;
                current.Left = current.Right = current.Parent = null;
                if (current.Color == BLACK) FixAfterDeletion(replace);
            }
            else if (current.Parent == null)
            {
                // The only node in the tree.
                Root = null;
            }
            else
            {
                // Leaf: use it as the phantom "double black" node before detaching it.
                if (current.Color == BLACK) FixAfterDeletion(current);
                if (current == current.Parent.Left) current.Parent.Left = null;
                else current.Parent.Right = null;
                current.Parent = null;
            }
        }

        /// <summary>
        /// Restores the red-black properties after removing a black node.
        /// </summary>
        /// <param name="current">The node that carries the extra black.</param>
        private void FixAfterDeletion(RedBlackTreeNode<TKey, TValue> current)
        {
            while (current != null && current != Root && GetColor(current) == BLACK)
            {
                if (current == GetLeft(GetParent(current)))
                {
                    var brother = GetRight(GetParent(current));
                    // Case 1: sibling is red -> make it black, rotate, and pick the new sibling.
                    if (GetColor(brother) == RED)
                    {
                        SetColor(brother, BLACK);
                        SetColor(GetParent(current), RED);
                        LeftRotation(GetParent(current));
                        brother = GetRight(GetParent(current));
                    }
                    // Case 2: both of the sibling's children are black -> push the extra black up.
                    if (GetColor(GetLeft(brother)) == BLACK && GetColor(GetRight(brother)) == BLACK)
                    {
                        SetColor(brother, RED);
                        current = GetParent(current);
                    }
                    else
                    {
                        // Case 3: sibling's right child black, left child red -> rotate sibling right.
                        if (GetColor(GetRight(brother)) == BLACK)
                        {
                            SetColor(GetLeft(brother), BLACK);
                            SetColor(brother, RED);
                            RightRotation(brother);
                            brother = GetRight(GetParent(current));
                        }
                        // Case 4: sibling's right child red -> recolour, rotate parent left, done.
                        SetColor(brother, GetColor(GetParent(current)));
                        SetColor(GetParent(current), BLACK);
                        SetColor(GetRight(brother), BLACK);
                        LeftRotation(GetParent(current));
                        current = Root;
                    }
                }
                else
                {
                    var brother = GetLeft(GetParent(current));
                    // Case 5: sibling is red -> make it black, rotate, and pick the new sibling.
                    if (GetColor(brother) == RED)
                    {
                        SetColor(brother, BLACK);
                        SetColor(GetParent(current), RED);
                        RightRotation(GetParent(current));
                        brother = GetLeft(GetParent(current));
                    }
                    // Case 6: both of the sibling's children are black -> push the extra black up.
                    if (GetColor(GetLeft(brother)) == BLACK && GetColor(GetRight(brother)) == BLACK)
                    {
                        SetColor(brother, RED);
                        current = GetParent(current);
                    }
                    else
                    {
                        // Case 7: sibling's left child black, right child red -> rotate sibling left.
                        if (GetColor(GetLeft(brother)) == BLACK)
                        {
                            SetColor(GetRight(brother), BLACK);
                            SetColor(brother, RED);
                            LeftRotation(brother);
                            brother = GetLeft(GetParent(current));
                        }
                        // Case 8: sibling's left child red -> recolour, rotate parent right, done.
                        SetColor(brother, GetColor(GetParent(current)));
                        SetColor(GetParent(current), BLACK);
                        SetColor(GetLeft(brother), BLACK);
                        RightRotation(GetParent(current));
                        current = Root;
                    }
                }
            }
            SetColor(current, BLACK);
        }

        /// <summary>
        /// Removes every entry from the tree.
        /// </summary>
        public virtual void Clear()
        {
            Root = null;
            Count = 0;
        }

        /// <summary>
        /// Colour of <paramref name="current"/>; null nodes count as BLACK.
        /// </summary>
        /// <param name="current">Node to inspect (may be null).</param>
        /// <returns>The node's colour.</returns>
        private bool GetColor(RedBlackTreeNode<TKey, TValue> current)
        {
            return current == null ? BLACK : current.Color;
        }

        /// <summary>
        /// Sets the colour of <paramref name="current"/>; no-op for null.
        /// </summary>
        /// <param name="current">Node to recolour (may be null).</param>
        /// <param name="color">New colour.</param>
        private void SetColor(RedBlackTreeNode<TKey, TValue> current, bool color)
        {
            if (current != null) current.Color = color;
        }

        /// <summary>
        /// Null-safe parent accessor.
        /// </summary>
        /// <param name="current">Node (may be null).</param>
        /// <returns>The parent, or null.</returns>
        private RedBlackTreeNode<TKey, TValue> GetParent(RedBlackTreeNode<TKey, TValue> current)
        {
            return current == null ? null : current.Parent;
        }

        /// <summary>
        /// Null-safe left-child accessor.
        /// </summary>
        /// <param name="current">Node (may be null).</param>
        /// <returns>The left child, or null.</returns>
        private RedBlackTreeNode<TKey, TValue> GetLeft(RedBlackTreeNode<TKey, TValue> current)
        {
            return current == null ? null : current.Left;
        }

        /// <summary>
        /// Null-safe right-child accessor.
        /// </summary>
        /// <param name="current">Node (may be null).</param>
        /// <returns>The right child, or null.</returns>
        private RedBlackTreeNode<TKey, TValue> GetRight(RedBlackTreeNode<TKey, TValue> current)
        {
            return current == null ? null : current.Right;
        }

        /// <summary>
        /// Left rotation around <paramref name="current"/>; its right child takes its place.
        /// Callers must ensure the right child is non-null.
        /// </summary>
        /// <param name="current">Pivot node.</param>
        private void LeftRotation(RedBlackTreeNode<TKey, TValue> current)
        {
            if (current == null) return;
            RedBlackTreeNode<TKey, TValue> temp = current.Right;
            current.Right = temp.Left;

            if (temp.Left != null) temp.Left.Parent = current;
            temp.Parent = current.Parent;

            if (current.Parent == null) Root = temp;
            else if (current.Parent.Left == current) current.Parent.Left = temp;
            else current.Parent.Right = temp;
            temp.Left = current;
            current.Parent = temp;
        }

        /// <summary>
        /// Right rotation around <paramref name="current"/>; its left child takes its place.
        /// Callers must ensure the left child is non-null.
        /// </summary>
        /// <param name="current">Pivot node.</param>
        private void RightRotation(RedBlackTreeNode<TKey, TValue> current)
        {
            if (current == null) return;
            RedBlackTreeNode<TKey, TValue> temp = current.Left;
            current.Left = temp.Right;

            if (temp.Right != null) temp.Right.Parent = current;
            temp.Parent = current.Parent;

            if (current.Parent == null) Root = temp;
            else if (current.Parent.Right == current) current.Parent.Right = temp;
            else current.Parent.Left = temp;
            temp.Right = current;
            current.Parent = temp;
        }

        /// <summary>
        /// Value associated with the largest key.
        /// </summary>
        /// <returns>The value of the right-most node.</returns>
        /// <exception cref="NullReferenceException">The tree is empty.</exception>
        public virtual TValue FindMax()
        {
            if (Root == null) throw new NullReferenceException("No data in this tree");
            return FindMax(Root).Value;
        }

        /// <summary>
        /// Right-most node of the subtree rooted at <paramref name="current"/> (must be non-null).
        /// </summary>
        /// <param name="current">Subtree root.</param>
        /// <returns>The node with the largest key in the subtree.</returns>
        private RedBlackTreeNode<TKey, TValue> FindMax(RedBlackTreeNode<TKey, TValue> current)
        {
            while (current.Right != null) current = current.Right;
            return current;
        }

        /// <summary>
        /// Value associated with the smallest key.
        /// </summary>
        /// <returns>The value of the left-most node.</returns>
        /// <exception cref="NullReferenceException">The tree is empty.</exception>
        public virtual TValue FindMin()
        {
            if (Root == null) throw new NullReferenceException("No data in this tree");
            return FindMin(Root).Value;
        }

        /// <summary>
        /// Left-most node of the subtree rooted at <paramref name="current"/> (must be non-null).
        /// </summary>
        /// <param name="current">Subtree root.</param>
        /// <returns>The node with the smallest key in the subtree.</returns>
        private RedBlackTreeNode<TKey, TValue> FindMin(RedBlackTreeNode<TKey, TValue> current)
        {
            while (current.Left != null) current = current.Left;
            return current;
        }

        /// <summary>
        /// Looks up the value stored for <paramref name="key"/>.
        /// </summary>
        /// <param name="key">Key to look up.</param>
        /// <returns>The stored value, or <c>default</c> when the key is absent.</returns>
        /// <exception cref="NullReferenceException">The tree is empty.</exception>
        public virtual TValue Get(TKey key)
        {
            if (Root == null) throw new NullReferenceException("No data in this tree");
            var node = FindNode(key);
            return node == null ? default(TValue) : node.Value;
        }

        /// <summary>
        /// Iterative BST search for the node whose key equals <paramref name="key"/>.
        /// </summary>
        /// <param name="key">Key to look up.</param>
        /// <returns>The matching node, or null if absent (or the tree is empty).</returns>
        private RedBlackTreeNode<TKey, TValue> FindNode(TKey key)
        {
            var current = Root;
            while (current != null)
            {
                // Only the sign of CompareTo is meaningful.
                int cmp = key.CompareTo(current.Key);
                if (cmp == 0) return current;
                current = cmp < 0 ? current.Left : current.Right;
            }
            return null;
        }

        /// <summary>
        /// Pre-order traversal (node, left, right).
        /// </summary>
        /// <returns>The nodes in pre-order; empty when the tree is empty.</returns>
        public virtual IEnumerable<Node<TKey, TValue>> PreOrder()
        {
            var elements = new List<Node<TKey, TValue>>();
            PreOrder(Root, elements);
            return elements;
        }

        /// <summary>
        /// Appends the subtree rooted at <paramref name="root"/> to <paramref name="elements"/> in pre-order.
        /// </summary>
        /// <param name="root">Subtree root (may be null).</param>
        /// <param name="elements">Accumulator list.</param>
        private void PreOrder(RedBlackTreeNode<TKey, TValue> root, List<Node<TKey, TValue>> elements)
        {
            if (root == null) return;
            elements.Add(root);
            PreOrder(root.Left, elements);
            PreOrder(root.Right, elements);
        }

        /// <summary>
        /// In-order traversal (left, node, right); yields nodes in ascending key order.
        /// </summary>
        /// <returns>The nodes in in-order; empty when the tree is empty.</returns>
        public virtual IEnumerable<Node<TKey, TValue>> InOrder()
        {
            var elements = new List<Node<TKey, TValue>>();
            InOrder(Root, elements);
            return elements;
        }

        /// <summary>
        /// Appends the subtree rooted at <paramref name="root"/> to <paramref name="elements"/> in in-order.
        /// </summary>
        /// <param name="root">Subtree root (may be null).</param>
        /// <param name="elements">Accumulator list.</param>
        private void InOrder(RedBlackTreeNode<TKey, TValue> root, List<Node<TKey, TValue>> elements)
        {
            if (root == null) return;
            InOrder(root.Left, elements);
            elements.Add(root);
            InOrder(root.Right, elements);
        }

        /// <summary>
        /// Post-order traversal (left, right, node).
        /// </summary>
        /// <returns>The nodes in post-order; empty when the tree is empty.</returns>
        public virtual IEnumerable<Node<TKey, TValue>> PostOrder()
        {
            var elements = new List<Node<TKey, TValue>>();
            PostOrder(Root, elements);
            return elements;
        }

        /// <summary>
        /// Appends the subtree rooted at <paramref name="root"/> to <paramref name="elements"/> in post-order.
        /// </summary>
        /// <param name="root">Subtree root (may be null).</param>
        /// <param name="elements">Accumulator list.</param>
        private void PostOrder(RedBlackTreeNode<TKey, TValue> root, List<Node<TKey, TValue>> elements)
        {
            if (root == null) return;
            PostOrder(root.Left, elements);
            PostOrder(root.Right, elements);
            elements.Add(root);
        }

        /// <summary>
        /// Enumerates the nodes in ascending key order (in-order).
        /// </summary>
        /// <returns>An enumerator over the in-order traversal.</returns>
        public virtual IEnumerator<Node<TKey, TValue>> GetEnumerator()
        {
            foreach (var item in InOrder())
            {
                yield return item;
            }
        }

        /// <summary>
        /// Non-generic enumerator; delegates to the generic <see cref="GetEnumerator"/>.
        /// </summary>
        /// <returns>An enumerator over the in-order traversal.</returns>
        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }
    }
}
